/*
    Copyright 2006-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Definitions of MPI utility functions.

#include "mpi_util.h"
#include "aux_grid.h"
#include "random.h"
#include <boost/lexical_cast.hpp>
#include <blitz/array.h>

#ifdef HAVE_MPI

#include <boost/mpi.hpp>

using namespace std;
using namespace boost::mpi;

/** This implementation class does the actual mpienv work. This is the
    pimpl idiom, shielding clients from having to include any mpi
    stuff. If MPI isn't used, we revert to stub functions that do the
    right thing in non-MPI situations.*/
class mcrx::mpienv_impl {
private:
  friend class mpienv;

  /// We keep a global instance of a communicator to avoid having to
  /// create one every time
  communicator world_;

  mpienv_impl(int argc, char** argv);

public:
  ~mpienv_impl();
  /// True on the rank-0 (master) task.
  bool is_mpi_master() const { return world_.rank()==0; }
  /// Rank of this task in the world communicator.
  int mpi_rank() const { return world_.rank(); }
  /// Number of tasks in the world communicator.
  int mpi_size() const { return world_.size(); }
  /// Blocks until all tasks have entered the barrier.
  void barrier() { world_.barrier(); }
  /// Access to the world communicator for collective operations.
  communicator& world() { return world_; }
};

/// Singleton implementation object; null until mpienv::init() is called.
boost::shared_ptr<mcrx::mpienv_impl> mcrx::mpienv::pimpl;

/** The constructor initializes MPI. Note that we must use the C API
    to initialize threading because the boost::mpi::environment
    doesn't call MPI_Init_thread; we require MPI_THREAD_MULTIPLE so
    multiple threads can make MPI calls concurrently.

    If the implementation cannot provide MPI_THREAD_MULTIPLE, MPI is
    finalized again and a runtime_error is thrown. Since the
    constructor throws, the destructor (which would call MPI_Finalize
    a second time) never runs.

    NOTE(review): the world_ member is default-constructed before this
    body runs, i.e. before MPI_Init_thread -- this relies on
    boost::mpi::communicator's default constructor making no actual
    MPI calls. Verify against the Boost.MPI version in use. */
mcrx::mpienv_impl::mpienv_impl(int argc, char** argv) {
  // Ask for full thread support; the implementation reports the level
  // it actually provides.
  int mpi_threading_provided;
  MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &mpi_threading_provided); 
  if (mpi_threading_provided < MPI_THREAD_MULTIPLE) { 
    cerr << "Error: Your MPI implementation does not support MPI_THREAD_MULTIPLE.\n";
    MPI_Finalize();
    throw std::runtime_error("MPI does not support MPI_THREAD_MULTIPLE");
  }
}

/** The destructor finalizes MPI. It only runs if the constructor
    succeeded, so MPI_Finalize is called exactly once. */
mcrx::mpienv_impl::~mpienv_impl()
{
  MPI_Finalize(); 
} 

/** Initializes the MPI environment. Before this is called, the
    rank/size functions will return the values appropriate for a
    non-MPI run, so unless you are actually going to use MPI, you
    don't need to call this. */
void mcrx::mpienv::init(int argc, char** argv)
{
  // It is a programming error to initialize the environment twice.
  assert(!pimpl);
  pimpl.reset(new mpienv_impl(argc, argv));
}

/** Returns the implementation object, or null if init() has not been
    called yet. */
mcrx::mpienv_impl* mcrx::mpienv::instance()
{
  mpienv_impl* impl = pimpl.get();
  return impl;
}

/** This function returns true if the task is the MPI master task. */
bool mcrx::is_mpi_master()
{
  mpienv_impl* e(mpienv::instance());
  return e ? e->is_mpi_master() : true; 
}

/** This function returns the mpi rank. */
int mcrx::mpi_rank()
{
  mpienv_impl* e(mpienv::instance());
  return e ? e->mpi_rank() : 0;
}

/** Returns the mpi rank formatted as a string, e.g. for building
    per-task file names. */
string mcrx::mpi_rank_string()
{
  const int rank = mpi_rank();
  return boost::lexical_cast<string>(rank);
}

/** This function returns the number of mpi tasks. */
int mcrx::mpi_size()
{
  mpienv_impl* e(mpienv::instance());
  return e ? e->mpi_size() : 1;
}

void mcrx::mpi_barrier()
{
  mpienv_impl* e(mpienv::instance());
  if(e) e->barrier();
}


/** Performs an mpi reduction using the specified value and functor,
    returning the result on all tasks. Implemented as a reduce to task
    0 followed by a broadcast of the result (equivalent to an
    all_reduce). */
template<typename T, typename T_red>
T mcrx::mpi_reduce(const T& val, const T_red& functor )
{
  communicator& world(mpienv::instance()->world());
  T result;
  if(is_mpi_master())
    // The root overload of reduce receives the reduced value.
    reduce(world, val, result, functor, 0);
  else
    // Non-root tasks only contribute; result stays default until the
    // broadcast below fills it in.
    reduce(world, val, functor, 0);
  broadcast(world, result, 0);
  return result;
}

/** Performs an mpi broadcast, taking the specified value on task 0
    and setting it on all tasks. */
template<typename T>
void mcrx::mpi_broadcast(T& val)
{
  communicator& world = mpienv::instance()->world();
  broadcast(world, val, 0);
}

/** Collects a bunch of objects from the tasks into a vector on task
    0. On the master task the returned vector has one entry per task,
    ordered by rank; all other tasks get an empty vector back. */
template<typename T> 
vector<T> mcrx::mpi_collect(const T& val)
{
  const int collect_tag = 7414;
  communicator& world(mpienv::instance()->world());

  if(mpi_rank()==0) {
    vector<T> collection;
    collection.reserve(mpi_size());   // one entry per task
    collection.push_back(val);

    // Receive from the other tasks in rank order so the result is
    // deterministically ordered.
    for(int t=1; t<mpi_size(); ++t) {
      T tmp;
      world.recv(t, collect_tag, tmp);
      collection.push_back(tmp);
    }
    return collection;
  }

  world.send(0, collect_tag, val);
  return vector<T>();
}


/** Sums the arrays across tasks, updating the array on the master
    task to be the sum. Unfortunately we can't use reduce with
    std::plus due to the fact that operator+ returns an expression,
    which Boost::MPI doesn't know what to do with. Instead, we just do
    an O(N) loop over all the tasks. This is unlikely to be a
    problem.  */
template<typename T_array>
void mcrx::mpi_sum_arrays(T_array& a)
{
  const int arr_sum_tag = 7411;
  communicator& world(mpienv::instance()->world());
  if(is_mpi_master()) {
    // Accumulate the contribution of every other task into a, one
    // rank at a time.
    for(int task=1; task<mpi_size(); ++task) {
      T_array contribution;
      world.recv(task, arr_sum_tag, contribution);
      a+=contribution;
    }
  }
  else {
    // Non-master tasks just ship their array to the master.
    T_array buffer(a);
    world.send(0, arr_sum_tag, buffer);
  }
  world.barrier();
}

/** Stacks the arrays across tasks along the specified axis, returning
    the stacked result on the master task. This obviously only works
    if the master task has enough memory to fit the concatenated
    array. Non-master tasks get an empty array back.

    NOTE(review): the placement arithmetic assumes every input array
    has a zero lbound along all axes -- confirm against callers. */
template<typename T_numtype, int N>
blitz::Array<T_numtype, N> 
mcrx::mpi_stack_arrays(const blitz::Array<T_numtype, N>& a, int axis)
{
  using namespace blitz;
  typedef blitz::Array<T_numtype, N> T_array;

  const int arr_stack_tag = 7412;
  communicator& world(mpienv::instance()->world());
  if(is_mpi_master()) {
    // Total extent along the stacking axis, summed over all tasks.
    int len;
    reduce(world, a.extent(axis), len, std::plus<int>(), 0);
    TinyVector<int, N> outshape=a.shape();
    outshape[axis]=len;
    T_array out(outshape);
    // mask is 1 along the stacking axis and 0 elsewhere, so
    // mask*start offsets a subdomain along that axis only.
    TinyVector<int, N> mask(0);
    mask[axis]=1;
    
    // Copy the master's own piece first, then append the other
    // tasks' pieces in rank order.
    out(RectDomain<N>(a.lbound(), a.ubound())) = a;
    int start=a.extent(axis);
    for(int i=1;i<mpi_size(); ++i) {
      T_array temp;
      world.recv(i,arr_stack_tag,temp);
      out(RectDomain<N>(mask*start, temp.ubound()+mask*start))=temp;
      start+=temp.extent(axis);
    }
    return out;
  }
  else {
    // Contribute our extent to the reduction and send our data.
    reduce(world, a.extent(axis), std::plus<int>(), 0);
    world.send(0, arr_stack_tag, a);
    return blitz::Array<T_numtype, N>();
  }
  // An unreachable world.barrier() after the if/else was removed:
  // both branches return above, so it could never execute.
}

/** Gathers a number from each task and returns the cumulative of
    those numbers. This is used to figure out which subset of a range
    is owned by each task: task 0 gets offset 0, and task t gets the
    sum of the values supplied by tasks 0..t-1.
    \param i the number of items owned by this task. */
int mcrx::mpi_calculate_offsets(int i)
{
  communicator& world(mpienv::instance()->world());
  const int offset_tag=7413;
  if(mpi_rank()==0) {
    // on root task, gather numbers, calculate partial_sum and send it
    // out.
    vector<int> sizes;
    gather(world, i, sizes, 0);
    partial_sum(sizes.begin(), sizes.end(), sizes.begin());
    // Task t's offset is the inclusive prefix sum through task t-1.
    // (The loop variable was renamed from i, which shadowed the
    // parameter i.)
    for(int t=1; t<world.size(); ++t)
      world.send(t, offset_tag, sizes[t-1]);
    return 0;
  }
  else {
    gather(world, i, 0);
    int offset;
    world.recv(0, offset_tag, offset);
    return offset;
  }
}

// explicit instantiation for commonly used types so we can insulate
// the users from needing mpi stuff.

// Scalar reductions: logical and/or for flags, plus for sums.
template bool mcrx::mpi_reduce(const bool&, const std::logical_and<bool>&);
template bool mcrx::mpi_reduce(const bool&, const std::logical_or<bool>&);
template int mcrx::mpi_reduce(const int&, const std::plus<int>&);
template float mcrx::mpi_reduce(const float&, const std::plus<float>&);
template double mcrx::mpi_reduce(const double&, const std::plus<double>&);

// Broadcasts of simple scalar types.
template void mcrx::mpi_broadcast(bool&); 
template void mcrx::mpi_broadcast(int&); 
template void mcrx::mpi_broadcast(long&); 
template void mcrx::mpi_broadcast(float&); 
template void mcrx::mpi_broadcast(double&); 

// Per-task collection of scalars and of RNG state vectors.
template vector<bool> mcrx::mpi_collect(const bool&);
template vector<int> mcrx::mpi_collect(const int&);
template vector<double> mcrx::mpi_collect(const double&);
template vector<vector<mcrx::T_rng::T_state> > mcrx::mpi_collect(const vector<T_rng::T_state>&);

// Array sums for 1-3 dimensional float/double/aux_pars_type arrays.
template void mcrx::mpi_sum_arrays(blitz::Array<float, 1>&);
template void mcrx::mpi_sum_arrays(blitz::Array<float, 2>&);
template void mcrx::mpi_sum_arrays(blitz::Array<float, 3>&);
template void mcrx::mpi_sum_arrays(blitz::Array<double, 1>&);
template void mcrx::mpi_sum_arrays(blitz::Array<double, 2>&);
template void mcrx::mpi_sum_arrays(blitz::Array<double, 3>&);
template void mcrx::mpi_sum_arrays(blitz::Array<mcrx::aux_pars_type, 1>&);
template void mcrx::mpi_sum_arrays(blitz::Array<mcrx::aux_pars_type, 2>&);
template void mcrx::mpi_sum_arrays(blitz::Array<mcrx::aux_pars_type, 3>&);

// Array stacking for 1-3 dimensional float/double arrays.
template blitz::Array<float, 1> mcrx::mpi_stack_arrays(const blitz::Array<float, 1>&, int);
template blitz::Array<float, 2> mcrx::mpi_stack_arrays(const blitz::Array<float, 2>&, int);
template blitz::Array<float, 3> mcrx::mpi_stack_arrays(const blitz::Array<float, 3>&, int);
template blitz::Array<double, 1> mcrx::mpi_stack_arrays(const blitz::Array<double, 1>&, int);
template blitz::Array<double, 2> mcrx::mpi_stack_arrays(const blitz::Array<double, 2>&, int);
template blitz::Array<double, 3> mcrx::mpi_stack_arrays(const blitz::Array<double, 3>&, int);


#endif

